import torch
import random
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import json
import string
from collections import Counter, defaultdict
from scipy.special import softmax
import os
from sklearn.cluster import KMeans

def get_icl_examples(train_dataset="arc-easy", test_dataset="arc-easy", emb=None, shuffle_seed=0, k=4, method="base", dp_choice="knn", subset_size=100, metric="cosine_similarity", if_qwa=False, if_self_emb=False, model_name="llama-3.1-8b-instruction", permutation=1, random_shuffled=True):
    """Load embeddings for a train/test pair and return ICL demonstration indices.

    Embeddings are read with ``torch.load`` from
    ``./data/<dataset>/<dataset>_<split>[_qwa]_<emb>.pt``. For ``emb="bm25"``
    outside qwa mode, the ``all-roberta-large-v1`` embeddings are loaded
    instead; ``ICLMethods`` reads the BM25 score matrices from separate files.

    Args:
        train_dataset: dataset folder providing the demonstration pool.
        test_dataset: dataset folder providing the queries.
        emb: embedding / score name (required).
        shuffle_seed, k, dp_choice, subset_size, metric, model_name,
        random_shuffled: forwarded unchanged to ``ICLMethods``.
        method: selector name understood by ``ICLMethods.get_method``. The
            default ``"base"`` is not one of those names, so callers must pass one.
        if_qwa: load the ``_qwa`` embedding files.
        if_self_emb: currently unused.
        permutation: ``-1`` reverses the demonstration order in every row.

    Returns:
        ``np.ndarray`` with one row of training indices per test example.

    Raises:
        ValueError: if ``emb`` is not given.
    """
    if emb is None:
        # Without this check the failure surfaces as an opaque TypeError during
        # path building (or an AttributeError inside ICLMethods).
        raise ValueError("get_icl_examples requires an embedding name via `emb`.")

    if if_qwa:
        print("using qwa mode")
        suffix = f"_qwa_{emb}"
    else:
        print("not using qwa mode")
        # BM25 has no embedding file of its own; load RoBERTa vectors as a stand-in.
        suffix = "_all-roberta-large-v1" if emb == "bm25" else f"_{emb}"

    train_path = os.path.join(".", "data", train_dataset, f"{train_dataset}_train{suffix}.pt")
    test_path = os.path.join(".", "data", test_dataset, f"{test_dataset}_test{suffix}.pt")
    train_embs = torch.load(train_path)
    test_embs = torch.load(test_path)

    icl = ICLMethods(train_embs, test_embs, train_dataset, test_dataset, k, shuffle_seed, dp_choice, subset_size, metric, emb, if_qwa, model_name, random_shuffled)
    idx_mat = np.array(icl.get_method(method))
    if permutation == -1:
        # Reverse demonstration order within each row.
        idx_mat = np.fliplr(idx_mat)
    return idx_mat

class ICLMethods:
    """Select in-context-learning demonstrations from a pool of training examples.

    Each selector (``get_*_data_idx``) returns one list of training-set indices
    per test example. Selectors read two score matrices built in ``__init__``:

    * ``metric_matrix_train``: train x train, with self-scores overwritten by -1.
    * ``metric_matrix_test``: test x train.

    The matrices are either loaded from precomputed ``.npy`` files (BM25 /
    BERTScore) or computed from the embeddings with :meth:`metric_calculate`.
    :meth:`get_method` dispatches to a selector by name.
    """

    # Number of answer options expected for each supported training set. Rows
    # with a different option count are skipped when answer keys are read.
    _CHOICE_COUNTS = {
        'arc-easy': 4,
        'arc-challenge': 4,
        'hellaswag': 4,
        'openbook_qa': 4,
        'commonsense_qa': 5,
        'sociali_qa': 3,
        'amazon_polarity': 2,
        "yelp_polarity": 2,
        "glue-sst2": 2,
        "customer_reviews": 2,
        "imdb": 2,
    }

    def __init__(self, train_embs, test_embs, train_dataset, test_dataset, k=4, shuffle_seed=0, dp_choice="knn", subset_size=100, metric="cosine_similarity", emb="all-roberta-large-v1", if_qwa=False, model_name="llama-3.1-8b-instruction", random_shuffled=True):
        """Store the configuration, seed the RNGs and build the score matrices.

        Args:
            train_embs: per-example training embeddings (anything ``len`` and
                ``np.asarray`` accept).
            test_embs: per-example test embeddings, same requirements.
            train_dataset, test_dataset: dataset folder names under ``./data``.
            k: number of demonstrations selected per test example.
            shuffle_seed: seed applied to both ``random`` and ``np.random``.
            dp_choice: how subset-based selectors order their subset per test
                example, ``"knn"`` or ``"random"``.
            subset_size: candidate-subset size for subset-based selectors.
            metric: similarity passed to :meth:`metric_calculate`.
            emb: embedding/score name. ``"bm25"`` and ``"bert_score*"`` load
                precomputed matrices instead of computing them.
            if_qwa: use the ``_qwa`` variant of precomputed BERTScore files.
            model_name: sub-folder holding ``probs.npy`` for confidence-based
                selectors.
            random_shuffled: reshuffle the picks of :meth:`get_random_data_idx`.
        """
        self.train_embs = train_embs
        self.test_embs = test_embs
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.k = k
        self.shuffle_seed = shuffle_seed
        self.dp_choice = dp_choice
        self.subset_size = subset_size
        self.metric = metric
        self.emb = emb
        self.if_qwa = if_qwa
        self.model_name = model_name
        self.random_shuffled = random_shuffled
        self.n_train = len(self.train_embs)
        self.n_test = len(self.test_embs)
        self.alpha = 0.5  # knn/diversity trade-off; get_method may override it
        random.seed(self.shuffle_seed)
        np.random.seed(self.shuffle_seed)

        # NOTE(review): precomputed matrices (both train and test side) are read from
        # the *test_dataset* folder; confirm this is intended when the datasets differ.
        if self.emb == "bm25":
            self.metric_matrix_train = np.load(f'./data/{self.test_dataset}/{self.test_dataset}_train_bm25_scores.npy')
            self.metric_matrix_test = np.load(f'./data/{self.test_dataset}/{self.test_dataset}_test_bm25_scores.npy')
        elif self.emb.startswith("bert_score"):
            if not self.if_qwa:
                self.metric_matrix_train = np.load(f"./data/{self.test_dataset}/{self.test_dataset}_train_{self.emb}.npy")
                self.metric_matrix_test = np.load(f'./data/{self.test_dataset}/{self.test_dataset}_{self.emb}.npy')
            else:
                self.metric_matrix_train = np.load(f"./data/{self.test_dataset}/{self.test_dataset}_train_qwa_{self.emb}.npy")
                self.metric_matrix_test = np.load(f'./data/{self.test_dataset}/{self.test_dataset}_test_qwa_{self.emb}.npy')
        else:
            self.metric_matrix_train = np.asarray(self.metric_calculate(self.train_embs, self.train_embs, self.metric))
            self.metric_matrix_test = np.asarray(self.metric_calculate(self.test_embs, self.train_embs, self.metric))
        # Overwrite self-similarity with -1. Some selectors reset this diagonal themselves.
        np.fill_diagonal(self.metric_matrix_train, -1)

    def get_method(self, method):
        """Run the selector named ``method`` and return its per-test index lists.

        Exact names:

        * ``"diversity"``, ``"random"``, ``"knn"``, ``"zero_shot"``
        * ``"greedy_diversity"``, ``"k_means"``, ``"better_diversity"``
        * ``"knn_fixed"``, ``"diversity_document"``, ``"diversity_document_query"``

        Parameterised names:

        * ``"diversity_<N>"``: diversity with ``subset_size = N``.
        * ``"knn_s3_"`` optionally followed by ``N``: knn_s3, with ``subset_size = N`` when given.
        * any name containing ``"knn_diversity"``: knn+diversity, with ``alpha``
          read from the third ``_``-separated token when present.

        The score matrices are deleted after a successful call to free memory,
        so each instance supports one successful call.

        Raises:
            ValueError: for an unknown method name.
        """
        exact_methods = {
            "diversity": self.get_diversity_data_idx,
            "random": self.get_random_data_idx,
            "knn": self.get_knn_data_idx,
            "zero_shot": self.zero_shot,
            "greedy_diversity": self.get_greedy_diversity_data_idx,
            "k_means": self.get_k_means_data_idx,
            "better_diversity": self.get_better_diversity_data_idx,
            "knn_fixed": self.get_knn_fixed_data_idx,
            "diversity_document": self.get_diversity_document_data_idx,
            "diversity_document_query": self.get_diversity_document_query_data_idx,
        }
        parts = method.split("_")
        if method in exact_methods:
            idx_mat = exact_methods[method]()
        elif method.startswith("diversity_") and parts[1].isdigit():
            self.subset_size = int(parts[1])
            idx_mat = self.get_diversity_data_idx()
        elif method.startswith("knn_s3_"):
            if len(parts) > 2 and parts[2].isdigit():
                self.subset_size = int(parts[2])
            idx_mat = self.get_knn_s3_data_idx()
        elif "knn_diversity" in method:
            if len(parts) > 2:
                self.alpha = float(parts[2])
            idx_mat = self.get_knn_diversity_data_idx()
        else:
            raise ValueError("Invalid ICL method.")

        del self.metric_matrix_train, self.metric_matrix_test
        return idx_mat

    def get_knn_fixed_data_idx(self):
        """Return training indices ``0..k-1`` for every test example (independent lists)."""
        return [list(range(self.k)) for _ in range(self.n_test)]

    def _greedy_min_similarity_subset(self, initial_sample):
        """Grow a ``subset_size`` subset that greedily minimises summed similarity.

        Starting from ``initial_sample``, each step adds the not-yet-chosen
        training example whose summed ``metric_matrix_train`` score against the
        examples chosen so far is smallest. Ties go to the lowest index.

        Raises:
            ValueError: if no remaining candidate scores below +inf, e.g. because
                the pool runs out when ``subset_size > n_train``.
        """
        annotated_set = [initial_sample]
        unannotated_set = [i for i in range(self.n_train) if i != initial_sample]
        for _ in range(self.subset_size - 1):
            min_metric_sum = float('inf')
            candidate_index = -1
            for sample_index in unannotated_set:
                metric_sum = np.sum(self.metric_matrix_train[sample_index, annotated_set])
                if metric_sum < min_metric_sum:
                    min_metric_sum = metric_sum
                    candidate_index = sample_index
            if candidate_index == -1:
                raise ValueError("Not enough samples to meet the requirement for each class.")
            annotated_set.append(candidate_index)
            unannotated_set.remove(candidate_index)
        assert len(annotated_set) == self.subset_size, f"{len(annotated_set)}, {len(unannotated_set)}"
        return annotated_set

    def _rank_subset_for_tests(self, annotated_set):
        """Pick ``k`` members of ``annotated_set`` for each test example.

        With ``dp_choice == "knn"`` members are ordered by decreasing similarity
        (:meth:`metric_calculate`) to the test embedding. With ``"random"`` they are
        shuffled per test example. ``self.train_embs`` is left unmodified.

        Raises:
            ValueError: for any other ``dp_choice``.
        """
        subset_embs = self.train_embs[annotated_set]
        selected_idx = []
        for test_emb in self.test_embs:
            if self.dp_choice == "knn":
                score_list = self.metric_calculate(subset_embs, test_emb, self.metric).reshape(-1)
                ranked = [annotated_set[i] for i in np.argsort(-score_list)]
            elif self.dp_choice == "random":
                ranked = annotated_set.copy()
                random.shuffle(ranked)
            else:
                raise ValueError("No method for dp_choice when diverse")
            selected_idx.append(ranked[:self.k])
        return selected_idx

    def get_diversity_data_idx(self):
        """Build a diversity subset from a random seed example, then rank it per test example."""
        print(f"subset size: {self.subset_size}")
        initial_sample = random.choice(range(self.n_train))
        annotated_set = self._greedy_min_similarity_subset(initial_sample)
        return self._rank_subset_for_tests(annotated_set)

    def get_knn_s3_data_idx(self):
        """Build a per-test shortlist, then greedily pick ``k`` examples from it.

        Stage 1 scores each training example ``j`` by ``sum(max(S[j], t))``.
        ``S`` is the shifted train similarity matrix and ``t`` is the test
        example's shifted similarity row. The ``subset_size`` lowest-scoring
        examples are kept. Stage 2 greedily adds the shortlisted example
        maximising ``sum(max(S_sub[c], running_max))`` over the shortlist's
        sub-matrix, updating ``running_max`` after each pick.

        Returns:
            list of ``n_test`` lists of at most ``k`` training indices.
        """
        # Clamp so the shortlist never requests more examples than exist.
        subset_size = min(self.subset_size, self.n_train)
        print(f"subset_size: {subset_size}")

        # Reset self-similarity to 1.0 and shift every score by +1
        # (for cosine scores this maps [-1, 1] onto [0, 2]).
        train_similarity_matrix = self.metric_matrix_train.copy()
        np.fill_diagonal(train_similarity_matrix, 1.0)
        train_similarity_matrix += 1
        # One reusable buffer for the element-wise maximum. Stage 1 is then
        # vectorised without allocating a fresh n_train x n_train array per test.
        scratch = np.empty(
            train_similarity_matrix.shape,
            dtype=np.result_type(train_similarity_matrix, self.metric_matrix_test),
        )

        selected_idx = []
        for test_idx in tqdm(range(self.n_test), desc="Calculating now"):
            test_similarity_vector = self.metric_matrix_test[test_idx] + 1
            # Row j of scratch is max(S[j], t); its row sum is the stage-1 score.
            np.maximum(train_similarity_matrix, test_similarity_vector, out=scratch)
            train_scores = scratch.sum(axis=1)
            lowest_score_indices = np.argsort(train_scores)[:subset_size]

            subset_similarity_matrix = train_similarity_matrix[np.ix_(lowest_score_indices, lowest_score_indices)]
            final_set = []
            max_vector = np.full(subset_size, -1.0)
            remaining_indices = list(range(subset_size))
            for _ in range(min(self.k, subset_size)):
                best_score = -float('inf')
                best_idx = -1
                for candidate_idx in remaining_indices:
                    score = np.sum(np.maximum(subset_similarity_matrix[candidate_idx], max_vector))
                    if score > best_score:
                        best_score = score
                        best_idx = candidate_idx
                final_set.append(best_idx)
                remaining_indices.remove(best_idx)
                max_vector = np.maximum(max_vector, subset_similarity_matrix[best_idx])

            selected_idx.append([lowest_score_indices[i] for i in final_set])
        return selected_idx

    def _cluster_train_embs(self):
        """Cluster the training embeddings into ``max(1, n_train // 5)`` k-means clusters.

        Returns:
            ``(n_clusters, cluster_labels)``.
        """
        n_clusters = max(1, self.n_train // 5)  # at least one cluster for tiny pools
        kmeans = KMeans(n_clusters=n_clusters, random_state=self.shuffle_seed)
        return n_clusters, kmeans.fit_predict(self.train_embs)

    def _cluster_balanced_greedy(self, sample_scores, cluster_labels, n_clusters):
        """Greedily pick ``min(k, n_train)`` examples by score plus a concave cluster bonus.

        A candidate's value is ``score * n_train + 6 * (sqrt(C + score) - sqrt(C))``.
        ``C`` is the summed score already taken from the candidate's cluster, so
        repeated picks from one cluster earn a shrinking bonus. Ties go to the
        lowest index, and NaN or -inf values are never selected.

        Args:
            sample_scores: 1-D array with one score per training example.
            cluster_labels: 1-D array with the cluster id of each example.
            n_clusters: number of distinct cluster ids.

        Raises:
            ValueError: if no selectable candidate remains.
        """
        cluster_labels = np.asarray(cluster_labels)
        cluster_sums = np.zeros(n_clusters)
        available = np.ones(self.n_train, dtype=bool)
        chosen = []
        for _ in range(min(self.k, self.n_train)):
            old_sums = cluster_sums[cluster_labels]
            gains = np.sqrt(old_sums + sample_scores) - np.sqrt(old_sums)
            totals = sample_scores * self.n_train + gains * 6
            # Mirrors a strict ``>`` scan started at -inf: NaN / -inf never win.
            eligible = available & (totals > -np.inf)
            if not eligible.any():
                raise ValueError("No eligible training example left to select.")
            best = int(np.argmax(np.where(eligible, totals, -np.inf)))
            chosen.append(best)
            available[best] = False
            cluster_sums[cluster_labels[best]] += sample_scores[best]
        return chosen

    def get_diversity_document_data_idx(self):
        """Select ``k`` cluster-balanced examples, the same list for every test example.

        Each example is scored by its mean shifted similarity to the whole pool.
        The shift resets the diagonal to 1 and then adds +1.
        """
        n_clusters, cluster_labels = self._cluster_train_embs()
        train_similarity_matrix = self.metric_matrix_train.copy()
        np.fill_diagonal(train_similarity_matrix, 1.0)
        train_similarity_matrix += 1
        sample_scores = np.sum(train_similarity_matrix, axis=1) / self.n_train
        chosen = self._cluster_balanced_greedy(sample_scores, cluster_labels, n_clusters)
        return [list(chosen) for _ in range(self.n_test)]

    def get_diversity_document_query_data_idx(self):
        """Select ``k`` cluster-balanced examples per test example.

        The per-test score averages the pool-level score from
        :meth:`get_diversity_document_data_idx` with the shifted (+1) similarity
        to the test example.
        """
        n_clusters, cluster_labels = self._cluster_train_embs()
        test_scores = self.metric_matrix_test + 1
        train_similarity_matrix = self.metric_matrix_train.copy()
        np.fill_diagonal(train_similarity_matrix, 1.0)
        train_similarity_matrix += 1
        # Query-independent part, hoisted out of the per-test loop.
        base_scores = np.sum(train_similarity_matrix, axis=1) / self.n_train

        selected_idx = []
        for test_idx in tqdm(range(self.n_test), desc="Calculating now"):
            sample_scores = (base_scores + test_scores[test_idx]) * 0.5
            selected_idx.append(self._cluster_balanced_greedy(sample_scores, cluster_labels, n_clusters))
        return selected_idx

    def get_greedy_diversity_data_idx(self):
        """Build a diversity subset seeded with the least-similar example overall, then rank it per test example."""
        initial_sample = int(np.argmin(np.sum(self.metric_matrix_train, axis=1)))
        annotated_set = self._greedy_min_similarity_subset(initial_sample)
        return self._rank_subset_for_tests(annotated_set)

    def get_knn_diversity_data_idx(self):
        """Per test example, trade off similarity to the query against diversity.

        Start from the training example most similar to the query. Then add
        ``k - 1`` more by maximising
        ``(1 - alpha) * sim / max(sim) + alpha * div / max(div)``, where
        ``div = |selected| - summed similarity to the selected examples``.

        NOTE(review): the max-normalisation assumes positive maxima. A non-positive
        maximum flips the ordering, or yields NaN at zero.
        """
        selected_idx = []
        knn_scores_all = self.metric_matrix_test

        for i in tqdm(range(self.n_test), desc="Calculating KNN+Diversity scores"):
            current_knn_scores = knn_scores_all[i]
            is_annotated = np.zeros(self.n_train, dtype=bool)
            # Running column sum of train similarities to every selected example.
            sum_of_diversities = np.zeros(self.n_train)

            nearest_index = np.argmax(current_knn_scores)
            annotated_set = [int(nearest_index)]
            is_annotated[nearest_index] = True

            if self.k > 1:
                sum_of_diversities += self.metric_matrix_train[:, nearest_index]
                for _ in range(self.k - 1):
                    unann_idx = np.flatnonzero(~is_annotated)
                    diversity_scores_subset = len(annotated_set) - sum_of_diversities[unann_idx]
                    knn_scores_subset = current_knn_scores[unann_idx]

                    knn_scores_norm = (knn_scores_subset / np.max(knn_scores_subset)).reshape(-1)
                    diversity_scores_norm = (diversity_scores_subset / np.max(diversity_scores_subset)).reshape(-1)
                    total_scores_subset = (1 - self.alpha) * knn_scores_norm + self.alpha * diversity_scores_norm

                    candidate_index = int(unann_idx[int(np.argmax(total_scores_subset))])
                    annotated_set.append(candidate_index)
                    is_annotated[candidate_index] = True
                    sum_of_diversities += self.metric_matrix_train[:, candidate_index]

            selected_idx.append(annotated_set)
        return selected_idx

    def old_get_knn_diversity_data_idx(self):
        """Legacy list-based version of :meth:`get_knn_diversity_data_idx` (not dispatched by get_method)."""
        selected_idx = []
        for i in tqdm(range(self.n_test), desc="Calculating KNN scores"):
            nearest_index = np.argmax(self.metric_matrix_test[i])
            annotated_set = [nearest_index]
            unannotated_set = [j for j in range(self.n_train) if j != nearest_index]

            for _ in range(self.k - 1):
                knn_scores = self.metric_matrix_test[i][unannotated_set]
                diversity_scores = len(annotated_set) - np.sum(self.metric_matrix_train[unannotated_set][:, annotated_set], axis=1)
                knn_scores_norm = (knn_scores / np.max(knn_scores)).reshape(-1)
                diversity_scores_norm = (diversity_scores / np.max(diversity_scores)).reshape(-1)
                total_scores = (1 - self.alpha) * knn_scores_norm + self.alpha * diversity_scores_norm
                candidate_index = unannotated_set[np.argmax(total_scores)]
                annotated_set.append(candidate_index)
                unannotated_set.remove(candidate_index)
            selected_idx.append(annotated_set)
        return selected_idx

    def get_random_data_idx(self):
        """Pick ``k`` random training indices per test example.

        The RNG call sequence (full shuffle, truncate, optional reshuffle) is kept
        so that seeded runs stay reproducible.
        """
        selected_idx = []
        for _ in range(self.n_test):
            selected_id = list(range(self.n_train))
            random.shuffle(selected_id)
            selected_id = selected_id[:self.k]
            if self.random_shuffled:
                random.shuffle(selected_id)
            selected_idx.append(selected_id)
        return selected_idx

    def get_knn_data_idx(self):
        """Return the ``k`` highest-scoring training indices per test row (descending), as an array."""
        return np.argsort(-self.metric_matrix_test, axis=1)[:, :self.k]

    def zero_shot(self):
        """Return an empty demonstration list for every test example."""
        return [[] for _ in range(self.n_test)]

    def _load_train_answer_keys(self):
        """Read ``answerKey`` from the training jsonl, keeping rows with the expected option count.

        Raises:
            ValueError: if ``train_dataset`` has no entry in ``_CHOICE_COUNTS``.
        """
        expected_choices = self._CHOICE_COUNTS.get(self.train_dataset)
        if expected_choices is None:
            raise ValueError("No answer key for this dataset")
        answer_keys = []
        with open(f"./data/{self.train_dataset}/{self.train_dataset}_train.jsonl", 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                if len(data['choices']['label']) == expected_choices:
                    answer_keys.append(data['answerKey'])
        return answer_keys

    def _load_answer_confidence(self):
        """Return the softmax of ``probs.npy`` at the gold-answer column for each kept training row."""
        logits = np.load(f"./data/{self.train_dataset}/{self.model_name}/probs.npy")
        answer_keys = np.array(self._load_train_answer_keys())
        # NOTE(review): the fancy indexing below needs integer answer keys. If the
        # jsonl stores letter labels (e.g. "A"), map them to column indices first.
        probs = softmax(logits, axis=1)
        return probs[np.arange(len(answer_keys)), answer_keys]

    def _confidence_strata_subset(self, confidence):
        """Sample up to ``subset_size`` indices spread evenly over confidence bins.

        Confidences are binned into tenths. The top bin (>= 0.9) is excluded.
        Bins are visited from smallest to largest, and each takes an equal share
        of the remaining budget. Bins smaller than their share are taken whole;
        larger ones are randomly subsampled.
        """
        confidence_bins = defaultdict(list)
        for idx, conf in enumerate(confidence):
            confidence_bins[min(int(conf * 10), 9)].append(idx)
        confidence_bins.pop(9, None)

        bin_counts = sorted(((b, len(ix)) for b, ix in confidence_bins.items()), key=lambda x: x[1])
        annotated_set = []
        remaining_count = self.subset_size
        n_bins = len(bin_counts)
        for position, (bin_idx, count) in enumerate(bin_counts):
            samples_needed = remaining_count // (n_bins - position)
            members = confidence_bins[bin_idx]
            if samples_needed >= count:
                annotated_set.extend(members)
                remaining_count -= count
            else:
                random.shuffle(members)
                annotated_set.extend(members[:samples_needed])
                remaining_count -= samples_needed
            if remaining_count <= 0:
                break
        return annotated_set

    def get_confidence_strata_data_idx(self):
        """Build a confidence-stratified subset (see :meth:`_confidence_strata_subset`) and rank it per test example."""
        annotated_set = self._confidence_strata_subset(self._load_answer_confidence())
        return self._rank_subset_for_tests(annotated_set)

    def metric_calculate(self, left_embs, right_embs, metric):
        """Return the pairwise similarity matrix between two sets of embeddings.

        1-D inputs are treated as a single row, so the result is always 2-D with
        shape ``(len(left), len(right))``.

        Raises:
            ValueError: if ``metric`` is not ``"cosine_similarity"``.
        """
        left = np.atleast_2d(np.asarray(left_embs))
        right = np.atleast_2d(np.asarray(right_embs))
        if metric != "cosine_similarity":
            raise ValueError(f"Unsupported metric: {metric}")
        dot_product = left @ right.T
        left_norm = np.linalg.norm(left, axis=1)
        right_norm = np.linalg.norm(right, axis=1)
        return np.asarray(dot_product / (left_norm[:, np.newaxis] * right_norm[np.newaxis, :]))

    def get_voke_k_data_idx(self):
        """Combine vote-k selection with confidence-ordered chunks (not dispatched by get_method).

        Steps:

        1. :meth:`fast_votek` picks ``subset_size // 10`` examples.
        2. Training rows are sorted by gold-answer confidence and split into
           ``subset_size`` chunks. Already-selected rows are dropped, and so are
           the first ``subset_size // 10`` (most confident) chunks.
        3. The highest vote score is taken from each remaining chunk.
        """
        selected_indices, _, selected_times = self.fast_votek(self.train_embs, self.subset_size // 10, k=150, vote_file=None)
        selected_set = set(selected_indices)
        confidence = self._load_answer_confidence()

        order = sorted(range(len(confidence)), key=lambda i: confidence[i], reverse=True)
        chunk_size = len(confidence) // self.subset_size
        chunks = [order[i * chunk_size:(i + 1) * chunk_size] for i in range(self.subset_size)]
        filtered_chunks = [[idx for idx in chunk if idx not in selected_set] for chunk in chunks]
        filtered_chunks = filtered_chunks[self.subset_size // 10:]

        # Re-vote using self.k nearest neighbours (not the 150 used in fast_votek).
        vote_stat = defaultdict(list)
        for i in range(self.n_train):
            cur_emb = self.train_embs[i].reshape(1, -1)
            cur_scores = np.sum(cosine_similarity(self.train_embs, cur_emb), axis=1)
            for idx in np.argsort(cur_scores).tolist()[-self.k - 1:-1]:
                if idx != i:
                    vote_stat[idx].append(i)

        vote_scores = defaultdict(int)
        for idx, candidates in vote_stat.items():
            if idx in selected_set:
                vote_scores[idx] = -100
                continue
            for one_support in candidates:
                if one_support not in selected_set:
                    vote_scores[idx] += 10 ** (-selected_times[one_support])

        max_indices = [max(chunk, key=lambda idx: vote_scores[idx]) for chunk in filtered_chunks if chunk]
        annotated_set = selected_indices + max_indices
        return self._rank_subset_for_tests(annotated_set)

    def fast_votek(self, embeddings, select_num, k, vote_file=None):
        """Vote-k selection of ``select_num`` representative examples.

        Each example votes for its ``k`` most cosine-similar neighbours; its own
        top score is excluded. Selection then repeatedly picks the example with
        the highest support. Each supporter contributes ``10 ** -times_already_used``.
        Votes are cached in ``vote_file`` when one is given.

        Returns:
            ``(selected_indices, votes, selected_times)``: ``votes`` is sorted by
            vote count, descending.
        """
        n = len(embeddings)
        if vote_file is not None and os.path.isfile(vote_file):
            with open(vote_file) as f:
                # JSON object keys are always strings; restore the int indices.
                vote_stat = {int(key): value for key, value in json.load(f).items()}
        else:
            vote_stat = defaultdict(list)
            for i in tqdm(range(n), desc='voting'):
                cur_emb = embeddings[i].reshape(1, -1)
                cur_scores = np.sum(cosine_similarity(embeddings, cur_emb), axis=1)
                for idx in np.argsort(cur_scores).tolist()[-k - 1:-1]:
                    if idx != i:
                        vote_stat[idx].append(i)
            if vote_file is not None:
                with open(vote_file, 'w') as f:
                    json.dump(vote_stat, f)

        votes = sorted(vote_stat.items(), key=lambda x: len(x[1]), reverse=True)
        selected_indices = []
        selected_set = set()
        selected_times = defaultdict(int)
        while len(selected_indices) < select_num:
            cur_scores = defaultdict(int)
            for idx, candidates in votes:
                if idx in selected_set:
                    cur_scores[idx] = -100
                    continue
                for one_support in candidates:
                    if one_support not in selected_set:
                        cur_scores[idx] += 10 ** (-selected_times[one_support])
            cur_selected_idx = max(cur_scores.items(), key=lambda x: x[1])[0]
            selected_indices.append(int(cur_selected_idx))
            selected_set.add(int(cur_selected_idx))
            for idx_support in vote_stat[cur_selected_idx]:
                selected_times[idx_support] += 1
        return selected_indices, votes, selected_times

    def get_better_diversity_data_idx(self):
        """Same selection as :meth:`get_diversity_data_idx` without the progress print."""
        initial_sample = random.choice(range(self.n_train))
        annotated_set = self._greedy_min_similarity_subset(initial_sample)
        return self._rank_subset_for_tests(annotated_set)

    def get_k_means_data_idx(self):
        """Pick one representative per k-means cluster, the same list for every test example.

        Clusters the training embeddings into ``k`` groups (``random_state=0``) and
        takes the member closest to each centre. If some clusters are empty, the
        list is topped up with random unused indices.
        """
        train_np = np.asarray(self.train_embs)
        kmeans = KMeans(n_clusters=self.k, random_state=0).fit(train_np)

        closest_points = []
        for cluster_id, center in enumerate(kmeans.cluster_centers_):
            members = np.flatnonzero(kmeans.labels_ == cluster_id)
            if members.size == 0:
                continue
            distances = np.linalg.norm(train_np[members] - center, axis=1)
            closest_points.append(members[np.argmin(distances)])

        if len(closest_points) < self.k:
            remaining = self.k - len(closest_points)
            available_indices = list(set(range(self.n_train)) - set(closest_points))
            if len(available_indices) >= remaining:
                closest_points.extend(random.sample(available_indices, remaining))

        return [closest_points.copy() for _ in range(len(self.test_embs))]
